{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "DATA = untar_data(URLs.IMAGENETTE_160)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "\n",
       "Train: LabelList (3798 items)\n",
       "x: ImageList\n",
       "Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160)\n",
       "y: CategoryList\n",
       "n03028079,n03028079,n03028079,n03028079,n03028079\n",
       "Path: /home/user/.fastai/data/imagenette-160;\n",
       "\n",
       "Valid: LabelList (161 items)\n",
       "x: ImageList\n",
       "Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160)\n",
       "y: CategoryList\n",
       "n03028079,n03028079,n03028079,n03028079,n03028079\n",
       "Path: /home/user/.fastai/data/imagenette-160;\n",
       "\n",
       "Test: None"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src = (ImageList.from_folder(DATA).filter_by_rand(0.3, seed=42)\n",
    "                .split_by_folder(valid='val')\n",
    "                .label_from_folder()\n",
    "                .transform(([flip_lr(p=0.5)], []), size=160))\n",
    "data = (src.databunch(bs=64, num_workers=6)\n",
    "           .normalize(imagenet_stats))\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from fastai import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,\n",
    "               norm_type:Optional[NormType]=NormType.Batch,  use_activ:bool=True, activ_fn:Callable=None, leaky:float=None,\n",
    "               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False):\n",
    "    \"\"\"Build an `nn.Sequential` holding a convolution (`ni` to `nf`) plus optional follow-up layers.\n",
    "\n",
    "    After the convolution come, in this order: the activation (if `use_activ`), a batchnorm layer\n",
    "    (if `norm_type` is `NormType.Batch` or `NormType.BatchZero`) and `SelfAttention` (if `self_attention`).\n",
    "    `activ_fn` is called with no arguments to build the activation; when it is None,\n",
    "    `partial(relu, inplace=True, leaky=leaky)` is used instead.\n",
    "    \"\"\"\n",
    "    make_activ = ifnone(activ_fn, partial(relu, inplace=True, leaky=leaky))\n",
    "    use_bn = norm_type in (NormType.Batch, NormType.BatchZero)\n",
    "    if padding is None:\n",
    "        padding = 0 if transpose else (ks-1)//2\n",
    "    if bias is None:\n",
    "        # by default the conv carries a bias only when no batchnorm layer is appended\n",
    "        bias = not use_bn\n",
    "    if transpose: conv_cls = nn.ConvTranspose2d\n",
    "    elif is_1d:   conv_cls = nn.Conv1d\n",
    "    else:         conv_cls = nn.Conv2d\n",
    "    conv = conv_cls(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)\n",
    "    conv = init_default(conv, init)\n",
    "    if norm_type == NormType.Weight:\n",
    "        conv = weight_norm(conv)\n",
    "    elif norm_type == NormType.Spectral:\n",
    "        conv = spectral_norm(conv)\n",
    "    modules = [conv]\n",
    "    if use_activ:\n",
    "        modules.append(make_activ())\n",
    "    if use_bn:\n",
    "        norm_cls = nn.BatchNorm1d if is_1d else nn.BatchNorm2d\n",
    "        modules.append(norm_cls(nf))\n",
    "    if self_attention:\n",
    "        modules.append(SelfAttention(nf))\n",
    "    return nn.Sequential(*modules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def simple_cnn(data, actns:Collection[int], kernel_szs:Collection[int]=None,\n",
    "               strides:Collection[int]=None, bn=False, activ_fn=None,\n",
    "               lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,\n",
    "               concat_pool:bool=True, bn_final:bool=False) -> nn.Sequential:\n",
    "    \"\"\"Stack of `conv_layer`s described by `actns`, `kernel_szs` and `strides`, topped by `create_head`.\n",
    "\n",
    "    One conv block is built per entry of `strides`, mapping `actns[i]` to `actns[i+1]` channels.\n",
    "    `kernel_szs` defaults to all 3s and `strides` to all 2s (`len(actns)-1` entries each).\n",
    "    `activ_fn` is forwarded to every `conv_layer`; the head outputs `data.c` values.\n",
    "    \"\"\"\n",
    "    n_convs = len(actns)-1\n",
    "    kernel_szs = ifnone(kernel_szs, [3]*n_convs)\n",
    "    strides = ifnone(strides, [2]*n_convs)\n",
    "    last_idx = len(strides)-1\n",
    "    body = []\n",
    "    for i in range_of(strides):\n",
    "        # batchnorm is requested only when `bn` is set, and never for the final conv block\n",
    "        norm = NormType.Batch if (bn and i < last_idx) else None\n",
    "        body.append(conv_layer(actns[i], actns[i+1], kernel_szs[i], stride=strides[i],\n",
    "                               norm_type=norm, activ_fn=activ_fn))\n",
    "    # the head is given twice `actns[-1]` input features when `concat_pool` is on\n",
    "    head_in = actns[-1] * (2 if concat_pool else 1)\n",
    "    head = create_head(head_in, data.c, lin_ftrs=lin_ftrs, ps=ps, concat_pool=concat_pool, bn_final=bn_final)\n",
    "    return nn.Sequential(*body, head)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "actns = [3,64,64,128,128,256,256,512,512]\n",
    "strides = [1,2]*(len(actns)//2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (1): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (2): Sequential(\n",
       "    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (3): Sequential(\n",
       "    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (4): Sequential(\n",
       "    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (5): Sequential(\n",
       "    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (6): Sequential(\n",
       "    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (7): Sequential(\n",
       "    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): ReLU(inplace=True)\n",
       "  )\n",
       "  (8): Sequential(\n",
       "    (0): AdaptiveConcatPool2d(\n",
       "      (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "      (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "    )\n",
       "    (1): Flatten()\n",
       "    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Dropout(p=0.25, inplace=False)\n",
       "    (4): Linear(in_features=1024, out_features=512, bias=True)\n",
       "    (5): ReLU(inplace=True)\n",
       "    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (7): Dropout(p=0.5, inplace=False)\n",
       "    (8): Linear(in_features=512, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mdl_relu = simple_cnn(data, actns=actns, strides=strides)\n",
    "mdl_relu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "lrn = Learner(data, mdl_relu, metrics=[accuracy,top_k_accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>top_k_accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.129265</td>\n",
       "      <td>2.851966</td>\n",
       "      <td>0.217391</td>\n",
       "      <td>0.614907</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.975083</td>\n",
       "      <td>2.081112</td>\n",
       "      <td>0.267081</td>\n",
       "      <td>0.739130</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.831574</td>\n",
       "      <td>1.718789</td>\n",
       "      <td>0.403727</td>\n",
       "      <td>0.844720</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.639187</td>\n",
       "      <td>1.452828</td>\n",
       "      <td>0.496894</td>\n",
       "      <td>0.888199</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.453105</td>\n",
       "      <td>1.301081</td>\n",
       "      <td>0.546584</td>\n",
       "      <td>0.919255</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lrn.fit_one_cycle(5, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "this Learner object self-destroyed - it still exists, but no longer usable\n"
     ]
    }
   ],
   "source": [
    "lrn.destroy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Mish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "class Mish(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x * torch.tanh(F.softplus(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (1): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (2): Sequential(\n",
       "    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (3): Sequential(\n",
       "    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (4): Sequential(\n",
       "    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (5): Sequential(\n",
       "    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (6): Sequential(\n",
       "    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (7): Sequential(\n",
       "    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): Mish()\n",
       "  )\n",
       "  (8): Sequential(\n",
       "    (0): AdaptiveConcatPool2d(\n",
       "      (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "      (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "    )\n",
       "    (1): Flatten()\n",
       "    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Dropout(p=0.25, inplace=False)\n",
       "    (4): Linear(in_features=1024, out_features=512, bias=True)\n",
       "    (5): ReLU(inplace=True)\n",
       "    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (7): Dropout(p=0.5, inplace=False)\n",
       "    (8): Linear(in_features=512, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mdl_mish = simple_cnn(data, actns=actns, strides=strides, activ_fn=Mish)\n",
    "mdl_mish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "lrn = Learner(data, mdl_mish, metrics=[accuracy,top_k_accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>top_k_accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.084685</td>\n",
       "      <td>2.054669</td>\n",
       "      <td>0.291925</td>\n",
       "      <td>0.782609</td>\n",
       "      <td>00:15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.830139</td>\n",
       "      <td>1.782138</td>\n",
       "      <td>0.397516</td>\n",
       "      <td>0.838509</td>\n",
       "      <td>00:15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.590294</td>\n",
       "      <td>1.337767</td>\n",
       "      <td>0.540373</td>\n",
       "      <td>0.894410</td>\n",
       "      <td>00:15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.334887</td>\n",
       "      <td>1.116282</td>\n",
       "      <td>0.639752</td>\n",
       "      <td>0.913043</td>\n",
       "      <td>00:15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.105988</td>\n",
       "      <td>0.972040</td>\n",
       "      <td>0.695652</td>\n",
       "      <td>0.944099</td>\n",
       "      <td>00:15</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lrn.fit_one_cycle(5, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "this Learner object self-destroyed - it still exists, but no longer usable\n"
     ]
    }
   ],
   "source": [
    "lrn.destroy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Mish CUDA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from mish_cuda import MishCuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (1): Sequential(\n",
       "    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (2): Sequential(\n",
       "    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (3): Sequential(\n",
       "    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (4): Sequential(\n",
       "    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (5): Sequential(\n",
       "    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (6): Sequential(\n",
       "    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (7): Sequential(\n",
       "    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    (1): MishCuda()\n",
       "  )\n",
       "  (8): Sequential(\n",
       "    (0): AdaptiveConcatPool2d(\n",
       "      (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "      (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "    )\n",
       "    (1): Flatten()\n",
       "    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): Dropout(p=0.25, inplace=False)\n",
       "    (4): Linear(in_features=1024, out_features=512, bias=True)\n",
       "    (5): ReLU(inplace=True)\n",
       "    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (7): Dropout(p=0.5, inplace=False)\n",
       "    (8): Linear(in_features=512, out_features=10, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mdl_mish = simple_cnn(data, actns=actns, strides=strides, activ_fn=MishCuda)\n",
    "mdl_mish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "lrn = Learner(data, mdl_mish, metrics=[accuracy,top_k_accuracy])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>top_k_accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.086314</td>\n",
       "      <td>2.356230</td>\n",
       "      <td>0.211180</td>\n",
       "      <td>0.739130</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.844088</td>\n",
       "      <td>2.806577</td>\n",
       "      <td>0.267081</td>\n",
       "      <td>0.677019</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.593210</td>\n",
       "      <td>1.359999</td>\n",
       "      <td>0.552795</td>\n",
       "      <td>0.919255</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.329840</td>\n",
       "      <td>1.198766</td>\n",
       "      <td>0.608696</td>\n",
       "      <td>0.944099</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.118359</td>\n",
       "      <td>0.987282</td>\n",
       "      <td>0.670807</td>\n",
       "      <td>0.950311</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lrn.fit_one_cycle(5, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "this Learner object self-destroyed - it still exists, but no longer usable\n"
     ]
    }
   ],
   "source": [
    "lrn.destroy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Stability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from mish_cuda import mish_backward\n",
    "def stress_bwd():\n",
    "    \"\"\"Stress-test `mish_backward` for non-finite gradients on widely spread inputs.\n",
    "\n",
    "    Runs up to 1000 random (1000, 100) batches. Returns `(inp, grad_out, y)` for the first batch\n",
    "    whose result contains an inf or NaN, or `None` when every batch is finite.\n",
    "    \"\"\"\n",
    "    grad_out = torch.ones(1000,100)\n",
    "    for _ in progress_bar(range(1000)):\n",
    "        # standard-normal noise plus one integer offset per row drawn from [-1000, 1000)\n",
    "        inp = torch.randn(1000,100) + torch.randint(-1000, 1000, (1000,1)).float()\n",
    "        y = mish_backward(inp, grad_out)\n",
    "        # NaN never compares equal to NaN, so `==` cannot detect it; use the dedicated predicates.\n",
    "        # `torch.isinf` flags both +inf and -inf.\n",
    "        n_inf = int(torch.isinf(y).sum())\n",
    "        n_nan = int(torch.isnan(y).sum())\n",
    "        if n_inf > 0 or n_nan > 0:\n",
    "            print(f\"Found non-finite: {n_inf} inf; {n_nan} nan\")\n",
    "            return (inp,grad_out,y)\n",
    "    print(\"No non-finites found\")\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='1000' class='' max='1000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [1000/1000 00:03<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No non-finites found\n"
     ]
    }
   ],
   "source": [
    "bad = stress_bwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-fastai]",
   "language": "python",
   "name": "conda-env-.conda-fastai-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
